import torch
import torch.nn as nn

'''
Model definitions: the Flatten layer and the model_gen factory
for fully connected networks.
'''

class Flatten(nn.Module):
    """Collapse every dimension after the first into a single dimension.

    The first dimension is kept as-is (presumably the batch dimension —
    confirm against callers); an input of shape ``(N, d1, d2, ...)`` comes
    out with shape ``(N, d1 * d2 * ...)``.
    """

    def forward(self, x):
        """Return ``x`` reshaped to ``(x.shape[0], -1)``."""
        # reshape() returns a view whenever view() would have succeeded and
        # only copies for non-contiguous inputs (e.g. after permute or
        # transpose), where view() would raise a RuntimeError instead.
        return x.reshape(x.shape[0], -1)

def model_gen(model_type):
    """Build a fully connected network selected by name.

    Every supported network starts with ``Flatten`` and expects 784
    features per sample after flattening; it ends in a single output unit.
    Hidden layers are 100 units wide with ReLU activations.

    Args:
        model_type: ``'2dnn'`` for two linear layers (784 -> 100 -> 1) or
            ``'4dnn'`` for four linear layers
            (784 -> 100 -> 100 -> 100 -> 1).

    Returns:
        An ``nn.Sequential`` containing the requested layers.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    if model_type == '2dnn':
        return nn.Sequential(
            Flatten(),
            nn.Linear(784, 100), nn.ReLU(),
            nn.Linear(100, 1),
        )
    elif model_type == '4dnn':
        return nn.Sequential(
            Flatten(),
            nn.Linear(784, 100), nn.ReLU(),
            nn.Linear(100, 100), nn.ReLU(),
            nn.Linear(100, 100), nn.ReLU(),
            nn.Linear(100, 1),
        )
    # Fail loudly here rather than returning None, which would only surface
    # later as an obscure error when the caller tries to use the model.
    raise ValueError(
        "Unknown model_type {!r}; expected '2dnn' or '4dnn'".format(model_type)
    )